// STL
#include <cmath>   // std::sqrt()
#include <string>  // std::string, std::to_string()
#include <utility> // std::pair<>, std::make_pair()
#include <vector>  // std::vector<>

// jScience
#include "jScience/linalg.hpp" // Matrix<>, Vector<>


// using namespace machine_learning;


//
// DESC: Convert a set of DBN (or an RBM) weights to that useful for a NN.
//
// NOTE: the following is how the matrix is stored for 2 RBMs
// NOTE: ... the extension to more is straightforward
//
//               [   nv0   ]       [   nh1   ]       [   nh2   ]   ...   [   no    ]
// [nv0]   [                                                                             ]
// |   |   |  [               ] [               ] [               ]   [               ]  |
// |   |   |  |       0       | |   nv0 x nh1   | |       0       |   |       0       |  |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// [   ]   |
// [nh1]   |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// |   |   |  | (nv0 x nh1)^T | |       0       | |   nh1 x nh2   |   |       0       |  |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// [   ]   |
// [nh2]   |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// |   |   |  |       0       | | (nh1 x nh2)^T | |       0       |   |   nh2 x no    |  |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// [   ]   |
//   .     |
//   .     |
//   .     |
// [no ]   |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// |   |   |  |       0       | |       0       | | (nh2 x no)^T  |   |       0       |  |
// |   |   |  [               ] [               ] [               ]   [               ]  |
// [   ]   [                                                                             ]
//
// NOTE: the output weights are initialized (randomly) in the same way as jNEURALNET does
//
// NOTE: nodes can have two biases: (i) visible biases in the lower layer (when the nodes are visible units), and (ii) hidden biases in the upper layer; only the latter are stored, because those are what a feedforward network uses — i.e., b_RBM[k] must have exactly W_RBM[k].size(1) entries (hidden biases only)
//
// PARAMS:
//   no    - number of output nodes appended after the topmost hidden layer
//   W_RBM - stacked RBM weight matrices; W_RBM[k] is (layer_k x layer_{k+1})
//   b_RBM - hidden biases only; b_RBM[k] must have size W_RBM[k].size(1)
//
// RETURNS: pair (W, b) where W is the full symmetric NN weight matrix and
//          b the bias vector, laid out as in the block diagram above
//
// PRE: W_RBM is non-empty (W_RBM.back() is dereferenced unconditionally) and
//      b_RBM.size() == W_RBM.size(); violating either is undefined behavior
//
inline std::pair<
                 Matrix<double>, // W
                 Vector<double>  // b
                 > machine_learning::DBN_to_NN_Wb(
                                                  const int                          no,
                                                  const std::vector<Matrix<double>> &W_RBM, // DBN weights
                                                  const std::vector<Vector<double>> &b_RBM  // DBN hidden biases
                                                  )
{
    // determine the size of the total matrix and bias vector:
    // every layer's visible side, plus the topmost hidden layer, plus the outputs
    std::size_t n = 0;
    for(std::size_t k = 0; k < W_RBM.size(); ++k)
    {
        n += W_RBM[k].size(0);
    }
    n += W_RBM.back().size(1);
    n += no;

    Matrix<double> W(n,n, 0.);
    Vector<double> b(n, 0.);

    //---------------------------------------------------------
    // POPULATE WEIGHT MATRIX AND BIAS VECTOR
    //---------------------------------------------------------

    // store RBMs: each RBM's weight block sits just above the diagonal
    // (and its transpose just below), offset by the layers already placed
    std::size_t offset = 0;

    for(std::size_t k = 0; k < W_RBM.size(); ++k)
    {
        for(std::size_t i = 0; i < W_RBM[k].size(0); ++i)
        {
            for(std::size_t j = 0; j < W_RBM[k].size(1); ++j)
            {
                std::size_t idx_i = i + offset;
                std::size_t idx_j = j + W_RBM[k].size(0) + offset; // hidden nodes follow the visible nodes

                // the NN matrix is symmetric: store the block and its transpose
                W(idx_i,idx_j) = W(idx_j,idx_i) = W_RBM[k](i,j);
            }
        }

        // hidden biases land on the hidden-layer rows of this RBM
        for(std::size_t i = 0; i < b_RBM[k].size(); ++i)
        {
            std::size_t idx_i = (i + W_RBM[k].size(0)) + offset;

            b(idx_i) = b_RBM[k](i);
        }

        offset += W_RBM[k].size(0);
    }

    // add outputs: connect the topmost hidden layer to 'no' output nodes
    // with random weights (the DBN has no trained output layer)

    // NOTE: the +1 is for the bias node, which I think should be included in the random initialization, since it is an incoming connection that is initialized in the same way
    const double alpha = std::sqrt( 3./(W_RBM.back().size(1)+1) );

    for(std::size_t k = 0; k < static_cast<std::size_t>(no); ++k)
    {
        std::size_t idx_j = k + W_RBM.back().size(1) + offset; // output nodes follow the topmost hidden layer

        for(std::size_t i = 0; i < W_RBM.back().size(1); ++i)
        {
            std::size_t idx_i = i + offset;

            W(idx_i,idx_j) = W(idx_j,idx_i) = rand_num_uniform_Mersenne_twister(-alpha, alpha);
        }

        // output biases are randomly initialized too (see NOTE above)
        b(idx_j) = rand_num_uniform_Mersenne_twister(-alpha, alpha);
    }

    return std::make_pair(
                          W,
                          b
                          );
}

/*
// NOTE(review): this scratch verifier is stale — the b_RBM vectors below are
// sized (nv+nh), i.e. they include visible biases, which conflicts with the
// hidden-biases-only contract of the live function (b_RBM[k] should have
// size W_RBM[k].size(1)); as written it would index the bias vector out of range
void verify()
{
    int nv0 = 2;
    int nh1 = 3;
    int nh2 = 4;
    int no  = 2;

    std::vector<Matrix<double>> W_RBM(2);
    std::vector<Vector<double>> b_RBM(2);

    W_RBM[0] = Matrix<double>(nv0,nh1);
    b_RBM[0] = Vector<double>((nv0+nh1));

    W_RBM[0](0,0) = 1.2;
    W_RBM[0](0,1) = 0.6;
    W_RBM[0](0,2) = 1.1;
    W_RBM[0](1,0) = 0.2;
    W_RBM[0](1,1) = 1.5;
    W_RBM[0](1,2) = 0.7;

    b_RBM[0](0)   = 0.1;
    b_RBM[0](1)   = 0.7;
    // ---
    b_RBM[0](2)   = 0.4;
    b_RBM[0](3)   = 0.9;
    b_RBM[0](4)   = 0.3;

    W_RBM[1] = Matrix<double>(nh1,nh2);
    b_RBM[1] = Vector<double>((nh1+nh2));

    W_RBM[1](0,0) = 0.7;
    W_RBM[1](0,1) = 0.3;
    W_RBM[1](0,2) = 1.6;
    W_RBM[1](0,3) = 1.3;
    W_RBM[1](1,0) = 0.9;
    W_RBM[1](1,1) = 1.2;
    W_RBM[1](1,2) = 0.6;
    W_RBM[1](1,3) = 1.8;
    W_RBM[1](2,0) = 0.1;
    W_RBM[1](2,1) = 1.2;
    W_RBM[1](2,2) = 0.5;
    W_RBM[1](2,3) = 1.6;

    b_RBM[1](0)   = 0.4;
    b_RBM[1](1)   = 0.8;
    b_RBM[1](2)   = 0.1;
    // ---
    b_RBM[1](3)   = 0.7;
    b_RBM[1](4)   = 0.5;
    b_RBM[1](5)   = 0.8;
    b_RBM[1](6)   = 0.2;

    Matrix<double> W_NN;
    Vector<double> b_NN;
    std::tie(
             W_NN,
             b_NN
             ) = DBN_to_NN_Wb(
                              no,
                              W_RBM,
                              b_RBM
                              );

    std::cout << std::fixed;
    std::cout.precision(1);
    for(auto i = 0; i < W_NN.size(0); ++i)
    {
        for(auto j = 0; j < W_NN.size(1); ++j)
        {
            std::cout << W_NN(i,j) << "   ";
        }
        std::cout << '\n';
    }

    std::cout << '\n';

    for(auto i = 0; i < b_NN.size(); ++i)
    {
        std::cout << b_NN(i) << '\n';
    }
}
*/
